# RUN: %PYTHON %s | FileCheck %s

import gc

from mlir.ir import *
from mlir.dialects._ods_common import equally_sized_accessor


def run(f):
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0


def add_dummy_value():
    return Operation.create(
        "custom.value", results=[IntegerType.get_signless(32)]
    ).result


def testOdsBuildDefaultImplicitRegions():
    class TestFixedRegionsOp(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (2, True)

    class TestVariadicRegionsOp(OpView):
        OPERATION_NAME = "custom.test_any_regions_op"
        _ODS_REGIONS = (2, False)

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            op = TestFixedRegionsOp.build_generic(results=[], operands=[])
            # CHECK: NUM_REGIONS: 2
            print(f"NUM_REGIONS: {len(op.regions)}")
            # Including a regions= that matches should be fine.
            op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2)
            print(f"NUM_REGIONS: {len(op.regions)}")
            # Reject greater than.
            try:
                op = TestFixedRegionsOp.build_generic(
                    results=[], operands=[], regions=3
                )
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
                print(f"ERROR:{e}")
            # Reject less than.
            try:
                op = TestFixedRegionsOp.build_generic(
                    results=[], operands=[], regions=1
                )
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
                print(f"ERROR:{e}")

            # If no regions specified for a variadic region op, build the minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[])
            # CHECK: DEFAULT_NUM_REGIONS: 2
            print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
            # Should also accept an explicit regions= that matches the minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=2)
            # CHECK: EQ_NUM_REGIONS: 2
            print(f"EQ_NUM_REGIONS: {len(op.regions)}")
            # And accept greater than minimum.
            # Should also accept an explicit regions= that matches the minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=3)
            # CHECK: GT_NUM_REGIONS: 3
            print(f"GT_NUM_REGIONS: {len(op.regions)}")
            # Should reject less than minimum.
            try:
                op = TestVariadicRegionsOp.build_generic(
                    results=[], operands=[], regions=1
                )
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
                print(f"ERROR:{e}")


run(testOdsBuildDefaultImplicitRegions)


def testOdsBuildDefaultNonVariadic():
    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v0 = add_dummy_value()
            v1 = add_dummy_value()
            t0 = IntegerType.get_signless(8)
            t1 = IntegerType.get_signless(16)
            op = TestOp.build_generic(results=[t0, t1], operands=[v0, v1])
            # CHECK: %[[V0:.+]] = "custom.value"
            # CHECK: %[[V1:.+]] = "custom.value"
            # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
            # CHECK-NOT: operandSegmentSizes
            # CHECK-NOT: resultSegmentSizes
            # CHECK-SAME: : (i32, i32) -> (i8, i16)
            print(m)


run(testOdsBuildDefaultNonVariadic)


def testOdsBuildDefaultSizedVariadic():
    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_OPERAND_SEGMENTS = [1, -1, 0]
        _ODS_RESULT_SEGMENTS = [-1, 0, 1]

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v0 = add_dummy_value()
            v1 = add_dummy_value()
            v2 = add_dummy_value()
            v3 = add_dummy_value()
            t0 = IntegerType.get_signless(8)
            t1 = IntegerType.get_signless(16)
            t2 = IntegerType.get_signless(32)
            t3 = IntegerType.get_signless(64)
            # CHECK: %[[V0:.+]] = "custom.value"
            # CHECK: %[[V1:.+]] = "custom.value"
            # CHECK: %[[V2:.+]] = "custom.value"
            # CHECK: %[[V3:.+]] = "custom.value"
            # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
            # CHECK-SAME: operandSegmentSizes = array<i32: 1, 2, 1>
            # CHECK-SAME: resultSegmentSizes = array<i32: 2, 1, 1>
            # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
            op = TestOp.build_generic(
                results=[[t0, t1], t2, t3], operands=[v0, [v1, v2], v3]
            )

            # Now test with optional omitted.
            # CHECK: "custom.test_op"(%[[V0]])
            # CHECK-SAME: operandSegmentSizes = array<i32: 1, 0, 0>
            # CHECK-SAME: resultSegmentSizes = array<i32: 0, 0, 1>
            # CHECK-SAME: (i32) -> i64
            op = TestOp.build_generic(
                results=[None, None, t3], operands=[v0, None, None]
            )
            print(m)

            # And verify that errors are raised for None in a required operand.
            try:
                op = TestOp.build_generic(
                    results=[None, None, t3], operands=[None, None, None]
                )
            except ValueError as e:
                # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
                print(f"OPERAND_CAST_ERROR:{e}")

            # And verify that errors are raised for None in a required result.
            try:
                op = TestOp.build_generic(
                    results=[None, None, None], operands=[v0, None, None]
                )
            except ValueError as e:
                # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
                print(f"RESULT_CAST_ERROR:{e}")

            # Variadic lists with None elements should reject.
            try:
                op = TestOp.build_generic(
                    results=[None, None, t3], operands=[v0, [None], None]
                )
            except ValueError as e:
                # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
                print(f"OPERAND_LIST_CAST_ERROR:{e}")
            try:
                op = TestOp.build_generic(
                    results=[[None], None, t3], operands=[v0, None, None]
                )
            except ValueError as e:
                # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
                print(f"RESULT_LIST_CAST_ERROR:{e}")


run(testOdsBuildDefaultSizedVariadic)


def testOdsBuildDefaultCastError():
    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v0 = add_dummy_value()
            v1 = add_dummy_value()
            t0 = IntegerType.get_signless(8)
            t1 = IntegerType.get_signless(16)
            try:
                op = TestOp.build_generic(results=[t0, t1], operands=[None, v1])
            except ValueError as e:
                # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
                print(f"ERROR: {e}")
            try:
                op = TestOp.build_generic(results=[t0, None], operands=[v0, v1])
            except ValueError as e:
                # CHECK: Result 1 of operation "custom.test_op" must be a Type
                print(f"ERROR: {e}")


run(testOdsBuildDefaultCastError)


def testOdsEquallySizedAccessor():
    class TestOpMultiResultSegments(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (1, True)

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v = add_dummy_value()
            ts = [IntegerType.get_signless(i * 8) for i in range(4)]

            op = TestOpMultiResultSegments.build_generic(
                results=[ts[0], ts[1], ts[2], ts[3]], operands=[v]
            )
            start, elements_per_group = equally_sized_accessor(op.results, 1, 3, 1, 0)
            # CHECK: start: 1, elements_per_group: 1
            print(f"start: {start}, elements_per_group: {elements_per_group}")
            # CHECK: i8
            print(op.results[start].type)

            start, elements_per_group = equally_sized_accessor(op.results, 1, 3, 1, 1)
            # CHECK: start: 2, elements_per_group: 1
            print(f"start: {start}, elements_per_group: {elements_per_group}")
            # CHECK: i16
            print(op.results[start].type)


run(testOdsEquallySizedAccessor)


def testOdsEquallySizedAccessorMultipleSegments():
    class TestOpMultiResultSegments(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (1, True)
        _ODS_RESULT_SEGMENTS = [0, -1, -1]

    def types(lst):
        return [e.type for e in lst]

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v = add_dummy_value()
            ts = [IntegerType.get_signless(i * 8) for i in range(7)]

            op = TestOpMultiResultSegments.build_generic(
                results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]],
                operands=[v],
            )
            start, elements_per_group = equally_sized_accessor(op.results, 1, 2, 1, 0)
            # CHECK: start: 1, elements_per_group: 3
            print(f"start: {start}, elements_per_group: {elements_per_group}")
            # CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
            print(types(op.results[start : start + elements_per_group]))

            start, elements_per_group = equally_sized_accessor(op.results, 1, 2, 1, 1)
            # CHECK: start: 4, elements_per_group: 3
            print(f"start: {start}, elements_per_group: {elements_per_group}")
            # CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
            print(types(op.results[start : start + elements_per_group]))


run(testOdsEquallySizedAccessorMultipleSegments)
